#*
# Authors: Anonymous
# This file is part of OASIS library.
#
# This file is based on the AdaHessian repository
# https://github.com/amirgholami/adahessian
#*

import torch
from torch.optim import Optimizer
import math
import copy
from IPython.core.debugger import set_trace

class OASIS(Optimizer):
    r'''Implements the OASIS algorithm.

    Each step scales the gradient elementwise by the inverse of a clipped
    (``max(|.|, alpha)``), exponentially averaged Hutchinson estimate of the
    Hessian diagonal, and adapts the step size ``eta`` of every parameter
    group from the observed change in parameters and gradients.

    ``loss.backward(create_graph=True)`` must be called before :meth:`step`
    so that Hessian-vector products can be formed from the gradients.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        device (torch.device or str, optional): device on which the random
            probe vectors are created; ``None`` (default) uses the device of
            each parameter
        beta (float, optional): coefficient for updating the Hessian diagonal
            estimate (default: 0.999)
        alpha (float, optional): lower bound on the absolute value of the
            Hessian diagonal elements (default: 1e-5)
        eta (float, optional): starting value for the step size eta
            (default: 1e-5)
        eps (float, optional): eps for numerical stability (default: 1e-8);
            validated and stored in the groups, not used by the update itself
        warmstart_samples (int, optional): number of samples used for the
            initial Hessian diagonal estimation (default: 500)
    '''

    def __init__(self, params, device=None, beta=0.999, alpha=1e-5, eta=1e-5, eps=1e-8, warmstart_samples=500):
        if not 0.0 <= beta < 1.0:
            raise ValueError("Invalid beta value: {}".format(beta))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))
        if not 0.0 < eta:
            raise ValueError("Invalid eta value: {}".format(eta))
        if not 0.0 < eps:
            # report the offending eps value (not eta)
            raise ValueError("Invalid eps value: {}".format(eps))
        if not 0.0 < warmstart_samples:
            raise ValueError("Invalid amount of warmstart samples: {}".format(warmstart_samples))

        defaults = dict(beta=beta, alpha=alpha, eta=eta, eps=eps)
        super(OASIS, self).__init__(params, defaults)

        # warmstart_samples is optimizer-wide rather than a per-group default:
        # the first step() draws this many probe vectors for all parameters at
        # once to obtain a better initial Hessian diagonal estimate.
        self.warmstarted = False
        self.warmstart_samples = warmstart_samples
        self.device = device

    def _sample_vhv(self, params, grads):
        r'''Draw one Rademacher probe ``v`` and return ``v * (H v)`` per parameter.'''
        # entries are uniformly +1 or -1
        vs = [2 * torch.randint_like(p, high=2, device=self.device) - 1 for p in params]
        # H v is the gradient of <grads, v> w.r.t. params; the graph is kept
        # so that further probes (and the caller) can reuse it
        hvs = torch.autograd.grad(grads,
                                  params,
                                  grad_outputs=vs,
                                  retain_graph=True)
        # * is component-wise multiplication
        return [v * hv for (v, hv) in zip(vs, hvs)]

    def get_trace(self, params, grads, n_samples=1):
        r'''Estimate the Hessian diagonal with Hutchinson's method.

        For a Rademacher vector ``v``, ``v * (H v)`` is an unbiased estimate
        of ``diag(H)``. ``H v`` is obtained by differentiating ``<grads, v>``
        with respect to ``params``. ``n_samples`` such estimates are averaged.

        Args:
            params (list of Tensors): parameters the gradients were taken w.r.t.
            grads (list of Tensors): gradients of the loss w.r.t. ``params``;
                they must carry a graph (``backward(create_graph=True)``)
            n_samples (int, optional): number of probe vectors averaged to
                obtain the Hessian diagonal estimate (default: 1; values
                below 1 behave like 1)

        Returns:
            list of Tensors: one Hessian diagonal estimate per parameter.

        Raises:
            RuntimeError: if a gradient has no ``grad_fn``.
        '''
        # check backward was called with create_graph set to True, else you can't differentiate grads
        for i, grad in enumerate(grads):
            if grad.grad_fn is None:
                raise RuntimeError('Gradient tensor {:} does not have grad_fn. When calling\n'.format(i) +
                           '\t\t\t  loss.backward(), make sure the option create_graph is\n' +
                           '\t\t\t  set to True.')

        vhvs = self._sample_vhv(params, grads)

        # running mean over the additional samples for a better estimate
        for counter in range(1, n_samples):
            vhvs_next = self._sample_vhv(params, grads)
            vhvs = [vhv * counter / (counter + 1) + vhv_next / (counter + 1)
                    for (vhv, vhv_next) in zip(vhvs, vhvs_next)]

        return vhvs

    def compute_dif_norms(self, groups=None):
        r'''Compute D-weighted norms of the latest parameter and gradient differences.

        With ``D = max(|exp_h_diag_avg|, alpha)`` taken elementwise, stores for
        each processed group::

            group['param_dif_norm'] = sqrt(sum((x_cur - x_prev) ** 2 * D))
            group['grad_dif_norm']  = sqrt(sum((x_grad_cur - x_grad_prev) ** 2 / D))

        Args:
            groups (iterable of dict, optional): parameter groups to process;
                defaults to all of ``self.param_groups``. Groups that have no
                previous iterate yet (``x_prev``/``x_grad_prev`` missing or
                ``None``) are skipped.
        '''
        if groups is None:
            groups = self.param_groups

        for group in groups:
            if group.get('x_prev') is None or group.get('x_grad_prev') is None:
                continue

            grad_dif_norm = 0.0
            param_dif_norm = 0.0
            for (d, x_cur, x_prev, x_grad_cur, x_grad_prev) in zip(group['exp_h_diag_avg'],
                                                                   group['x_cur'],
                                                                   group['x_prev'],
                                                                   group['x_grad_cur'],
                                                                   group['x_grad_prev']):
                weight = torch.clamp(torch.abs(d), min=group['alpha'])
                grad_dif_norm += torch.sum(((x_grad_cur - x_grad_prev) ** 2) / weight)
                param_dif_norm += torch.sum(((x_cur - x_prev) ** 2) * weight)

            # as_tensor also covers a group whose lists are empty
            group['grad_dif_norm'] = torch.sqrt(torch.as_tensor(grad_dif_norm))
            group['param_dif_norm'] = torch.sqrt(torch.as_tensor(param_dif_norm))

    def _update_group(self, group, params, h_diags):
        r'''Update the curvature memory and step size of one group and move its parameters.

        Args:
            group (dict): the parameter group being updated
            params (list of Tensors): the group's parameters that have a gradient
            h_diags (list of Tensors): fresh Hessian diagonal estimates, aligned
                with ``params``
        '''
        # NOTE(review): the per-group lists are aligned with ``params`` by
        # position, so this assumes the same parameters receive gradients at
        # every step.
        if 'exp_h_diag_avg' not in group:
            # first step for this group: no history yet
            group['x_prev'] = None
            group['x_grad_prev'] = None
            group['exp_h_diag_avg'] = list(h_diags)
            group['theta'] = float('Inf')
            group['lambda_prev_max'] = None
        else:
            group['x_prev'] = group['x_cur']
            group['x_grad_prev'] = group['x_grad_cur']
            beta = group['beta']
            # exponential moving average of the Hessian diagonal (in place)
            group['exp_h_diag_avg'] = [prev_h_diag.mul_(beta).add_(h_diag, alpha=1 - beta)
                                       for (prev_h_diag, h_diag) in zip(group['exp_h_diag_avg'], h_diags)]
            group['lambda_prev_max'] = group['lambda_cur_max']

        group['lambda_cur_min'] = min(torch.min(d) for d in group['exp_h_diag_avg'])
        group['lambda_cur_max'] = max(torch.max(d) for d in group['exp_h_diag_avg'])
        # graph-free snapshots of the current iterate and gradient
        group['x_cur'] = [p.detach().clone() for p in params]
        group['x_grad_cur'] = [p.grad.detach().clone() for p in params]

        # adapt eta once a previous iterate is available
        if (group['x_prev'] is not None and group['x_grad_prev'] is not None
                and group['lambda_prev_max'] is not None):
            eta_prev = group['eta']
            self.compute_dif_norms([group])
            eta_new = float(min(math.sqrt(1 + group['theta']) * eta_prev,
                                group['param_dif_norm'] / (2 * group['grad_dif_norm'])))
            if not (math.isfinite(eta_new) and eta_new > 0):
                # degenerate ratio (0/0 or x/0 while theta is still inf):
                # keep the previous step size instead of an inf/nan update
                eta_new = float(eta_prev)
            group['eta'] = eta_new
            group['theta'] = eta_new / eta_prev

        eta = float(group['eta'])
        alpha = group['alpha']
        for (p, h_diag) in zip(params, group['exp_h_diag_avg']):
            # detach_ drops the create_graph history held by p.grad
            grad = p.grad.detach_()
            p.data.add_(grad / torch.clamp(torch.abs(h_diag), min=alpha), alpha=-eta)

    def step(self, closure=None):
        r'''Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.
        '''
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # parameters that received a gradient, kept per group so that each
        # group updates only its own parameters; lists (not generators) are
        # required because they are traversed several times
        group_params = [[p for p in group['params'] if p.grad is not None]
                        for group in self.param_groups]
        params = [p for params_g in group_params for p in params_g]
        if not params:
            # nothing to update and no graph to differentiate
            return loss
        grads = [p.grad for p in params]

        # one Hessian diagonal estimate for all parameters; the very first
        # call averages warmstart_samples probes, later calls use one
        n_samples = 1 if self.warmstarted else self.warmstart_samples
        h_diags = self.get_trace(params, grads, n_samples=n_samples)
        self.warmstarted = True

        offset = 0
        for group, params_g in zip(self.param_groups, group_params):
            h_diags_g = h_diags[offset:offset + len(params_g)]
            offset += len(params_g)
            if params_g:
                self._update_group(group, params_g, h_diags_g)

        return loss